from PyQt5 import QtCore
from PyQt5.QtCore import QEvent

from structures.abs_structure_args import *
from structures.dog_leg_args import DogLegAbsStructureArgs
from structures.plot_clock_args import PlotClockStructureArgs
from utils.math_tool import *
from utils.math_tool import rotate


class BasicStructure(AbsStructureArgs):
    """Interactive model of a planar linkage driven by the two links AB and CD.

    Geometry (all stored coordinates and link lengths are in screen units,
    i.e. physical units multiplied by ``self._ratio``):

    * A and C are fixed pivots on the x axis, ``links[ORG_LINK_AC]`` apart and
      symmetric about the origin.
    * B is the free end of link AB, D the free end of link CD.
    * E joins links BE and DE.
    * F is the end point.  B, E and F form a triangle with fixed side lengths:
      BF is derived from BE, EF and ``theta_degrees`` (law of cosines).

    Link lengths and the initial end point come from the currently selected
    entry of ``ALL_STRUCTURES``.
    """

    def __init__(self):
        # Ratio between screen coordinates and physical units
        # (screen = physical * _ratio; see forward_ratio / inverse_ratio).
        self._ratio = 3

        # Which side the fixed link sits on (LinkSide.LEFT or LinkSide.RIGHT).
        self.solid_link_side = LinkSide.LEFT

        # Key point currently grabbed by the mouse (POINT_NONE when idle).
        self.capture_key_point = KeyPoint.POINT_NONE
        # Reference directions the joint angles are measured against.
        self.refer_vector_x_negative = np.array([-1, 0], dtype=np.double)
        self.refer_vector_x_positive = np.array([1, 0], dtype=np.double)

        # Joint angles of links AB (alpha) and CD (beta), in radians.
        self.alpha_radians = 0
        self.beta_radians = 0
        # Angle at E between links EB and EF, in degrees (see __update_link_end).
        self.theta_degrees = 0

        # Guide line drawn while dragging B or D: (pivot, mouse position).
        self.drag_line = None
        self.point_a = None
        self.point_b = None
        self.point_c = None
        self.point_d = None
        self.point_e = None
        self.point_f = None

        # List holding all recorded trajectories.
        self.trajectory_list = []
        # Trajectory to walk through.
        self.walks_trajectory = None
        # Path points.
        self.path_trajectory = None

        self.links_visible = {
            LINK_AB: True,
            LINK_CD: True,
            LINK_BE: False,
            LINK_DE: False,
            LINK_EF: False,
        }

        self.ALL_STRUCTURES = [
            DogLegAbsStructureArgs(),
            PlotClockStructureArgs(),
        ]

        self.current_structure: AbsStructureArgs = self.ALL_STRUCTURES[0]

        self.links = self._get_init_links()
        self.__init_structure_args()
        self.__update_link_end()
        self.__try_update_key_points(KeyPoint.POINT_END)

    def _get_init_links(self):
        """Return the current structure's initial link lengths, scaled to screen units."""
        links = self.current_structure._get_init_links()
        return {key: self.forward_ratio(value) for key, value in links.items()}

    def _get_init_point_end(self):
        """Return the current structure's initial end point, scaled to screen units."""
        point_end = self.current_structure._get_init_point_end()
        return self.forward_ratio(point_end)

    def _get_init_theta_degrees(self):
        """Return the current structure's initial theta, in degrees."""
        return self.current_structure._get_init_theta_degrees()

    def is_link_visible(self, link_name):
        """Return whether ``link_name`` should be drawn (unknown links are hidden)."""
        return self.links_visible.get(link_name, False)

    def __update_fixed_pivots(self):
        """Place the fixed pivots A and C on the x axis from the AC length and side."""
        half_ac = self.solid_link_side.value * self.links[ORG_LINK_AC] / 2
        self.point_a = np.array([-half_ac, 0], dtype=np.double)
        self.point_c = np.array([half_ac, 0], dtype=np.double)

    def __init_structure_args(self):
        """Reset theta, the fixed pivots and the end point to the structure defaults."""
        self.theta_degrees = self._get_init_theta_degrees()
        self.__update_fixed_pivots()
        self.point_f = np.array(self._get_init_point_end(), dtype=np.double)

    def reset_links(self):
        """Restore link lengths and pose from the current structure's defaults."""
        self.links = self._get_init_links()
        self.__init_structure_args()
        self.__update_link_end()
        self.__try_update_key_points(KeyPoint.POINT_END)

    def set_mode(self, mode):
        """Select ``ALL_STRUCTURES[mode]`` and reset the linkage to its defaults."""
        self.current_structure = self.ALL_STRUCTURES[mode]
        self.reset_links()

    def get_link_value(self, link_name):
        """Return the length of ``link_name`` in physical units."""
        return self.inverse_ratio(self.links[link_name])

    def get_angle_degrees_value(self, angle_name):
        """Return the named angle in degrees (beta is reported with its sign flipped); 0 if unknown."""
        if angle_name == ANGLE_ALPHA:
            return np.rad2deg(self.alpha_radians)
        elif angle_name == ANGLE_BETA:
            return -np.rad2deg(self.beta_radians)
        elif angle_name == ANGLE_THETA:
            return self.theta_degrees

        return 0

    def update_angle_value(self, angle_name, degrees):
        """Set the named angle (degrees).

        :return: True if the linkage could be solved for the new angle,
            False otherwise or if ``angle_name`` is unknown
        """
        if angle_name == ANGLE_ALPHA:
            return self.try_update_alpha(degrees)
        elif angle_name == ANGLE_BETA:
            return self.try_update_beta(degrees)
        elif angle_name == ANGLE_THETA:
            return self.update_theta(degrees)

        return False

    def update_theta(self, degrees):
        """Set theta (degrees) and re-solve E and F from the current B and D.

        On failure theta and the derived BF length are restored.

        :return: True if the new theta yields a valid pose, otherwise False
        """
        old_theta = self.theta_degrees
        old_link_bf = self.links[LINK_BF]

        self.theta_degrees = degrees
        self.__update_link_end()
        # POINT_LEFT without a temp point = forward kinematics from current B/D.
        if self.__try_update_key_points(KeyPoint.POINT_LEFT):
            return True

        self.theta_degrees = old_theta
        self.links[LINK_BF] = old_link_bf
        return False

    def update_link_value(self, link_name, value):
        """Set a link length (physical units) and re-solve with the end point F held fixed.

        Every change made here is rolled back if no valid pose exists.

        :param link_name: key of the link in ``self.links``
        :param value: new length in physical units
        :return: True on success, False if the link is unknown or unreachable
        """
        value = self.forward_ratio(value)
        if link_name not in self.links:
            return False

        old_value = self.links[link_name]
        self.links[link_name] = value

        old_link_bf = None
        old_point_a, old_point_c = None, None

        if link_name == LINK_BE or link_name == LINK_EF:
            # BF is derived from BE, EF and theta, so it must follow.
            old_link_bf = self.links[LINK_BF]
            self.__update_link_end()
        elif link_name == ORG_LINK_AC:
            # The base length moves both fixed pivots.
            old_point_a, old_point_c = self.point_a, self.point_c
            self.__update_fixed_pivots()

        if self.__try_update_key_points(KeyPoint.POINT_END):
            return True

        # Update failed: roll back.
        self.links[link_name] = old_value
        if old_link_bf is not None:
            self.links[LINK_BF] = old_link_bf
        if old_point_a is not None:
            self.point_a = old_point_a
        if old_point_c is not None:
            self.point_c = old_point_c

        return False

    def __update_link_end(self):
        """Recompute BF from BE, EF and theta using the law of cosines."""
        self.links[LINK_BF] = np.sqrt(
            self.links[LINK_BE] ** 2 +
            self.links[LINK_EF] ** 2 -
            2 * self.links[LINK_BE] * self.links[LINK_EF] * np.cos(np.deg2rad(self.theta_degrees))
        )

    def _joint_angles(self, point_b, point_d):
        """Return (alpha, beta) in radians for links AB and CD ending at the given B and D.

        The reference direction of each angle depends on the fixed-link side.
        """
        if self.is_left_mode():
            alpha = calc_radians_between_vector(point_b - self.point_a, self.refer_vector_x_negative)
            beta = calc_radians_between_vector(point_d - self.point_c, self.refer_vector_x_positive)
        else:
            alpha = -calc_radians_between_vector(point_b - self.point_a, self.refer_vector_x_positive)
            beta = -calc_radians_between_vector(point_d - self.point_c, self.refer_vector_x_negative)
        return alpha, beta

    def __try_update_key_points(self, key_point: KeyPoint = KeyPoint.POINT_NONE, temp_point=None):
        """Try to re-solve the linkage while holding one key point.

        :param key_point: the key point to hold. POINT_LEFT / POINT_RIGHT /
            POINT_FORWARD use forward kinematics from B and D; any other value
            (except POINT_NONE) uses inverse kinematics from the end point F.
        :param temp_point: candidate value for the held point (a (B, D) pair
            for POINT_FORWARD); None means use the current value
        :return: True if a valid pose was found and stored, otherwise False
            (state is left untouched)
        """
        if key_point == KeyPoint.POINT_NONE:
            return False

        if key_point in (KeyPoint.POINT_LEFT, KeyPoint.POINT_RIGHT, KeyPoint.POINT_FORWARD):
            # Forward kinematics
            kinematics = self.__forward_kinematics(key_point, temp_point)
            if kinematics is None:
                return False
            # Commit B and D as well, so the angles below reflect the new pose
            # instead of the previous one.
            self.point_b, self.point_d, self.point_e, self.point_f = kinematics
        else:
            # Inverse kinematics
            kinematics = self.__inverse_kinematics(temp_point)
            if kinematics is None:
                return False
            self.point_b, self.point_d, self.point_e = kinematics

        # Joint angles depend on the LinkSide convention.
        self.alpha_radians, self.beta_radians = self._joint_angles(self.point_b, self.point_d)
        return True

    def set_left_mode(self, is_left_mode):
        """Choose the fixed-link side and reset the linkage."""
        self.solid_link_side = LinkSide.LEFT if is_left_mode else LinkSide.RIGHT
        self.reset_links()

    def is_left_mode(self):
        """Return True when the fixed link is on the left side."""
        return self.solid_link_side == LinkSide.LEFT

    def inverse_not_update(self, end_point):
        """Solve inverse kinematics for ``end_point`` without changing any state.

        :param end_point: end-point coordinates (screen units)
        :return: (alpha_degrees, beta_degrees) in the same convention as
            ``get_angle_degrees_value(ANGLE_ALPHA / ANGLE_BETA)``, or None if
            the end point is unreachable
        """
        if not isinstance(end_point, np.ndarray):
            end_point = np.array(end_point)

        kinematics = self.__inverse_kinematics(end_point)
        if kinematics is None:
            # Inverse kinematics failed.
            return None

        point_b, point_d, _point_e = kinematics
        # Use the freshly solved B/D, never the stored pose.
        alpha_radians, beta_radians = self._joint_angles(point_b, point_d)
        return np.rad2deg(alpha_radians), -np.rad2deg(beta_radians)

    def __forward_kinematics(self, key_point, temp_point):
        """Solve E and F from B and D.

        :param key_point: POINT_FORWARD (``temp_point`` is a (B, D) pair),
            POINT_LEFT (``temp_point``, if given, replaces B) or POINT_RIGHT
            (``temp_point``, if given, replaces D)
        :param temp_point: candidate position(s); None keeps the current B/D
        :return: (point_b, point_d, point_e, point_f), or None if no solution
        """
        if key_point == KeyPoint.POINT_FORWARD:
            point_b, point_d = temp_point
        else:
            has_temp = temp_point is not None
            point_b = temp_point if (key_point == KeyPoint.POINT_LEFT and has_temp) else self.point_b
            point_d = temp_point if (key_point == KeyPoint.POINT_RIGHT and has_temp) else self.point_d

        # E: intersection of links BE and DE; take the candidate with the largest y.
        intersection_cross_points = calculate_intersection(point_b, point_d, self.links[LINK_BE], self.links[LINK_DE])
        if intersection_cross_points is None:
            return None

        point_e = intersection_cross_points[np.argmax(intersection_cross_points[:, 1])]

        # F: intersection of links BF and EF.
        intersection_end_points = calculate_intersection(point_b, point_e, self.links[LINK_BF], self.links[LINK_EF])
        if intersection_end_points is None:
            print("------------求E交点失败！")
            return None

        # Intended: vector BF lies clockwise of vector BE (to the right of BE).
        # The candidate is chosen by its angle from BE, largest in left mode,
        # smallest in right mode.
        vec_eb = point_e - point_b
        points_f_radians = [calc_radians_between_vector(vec_eb, p_f - point_b) for p_f in intersection_end_points]
        if self.is_left_mode():
            point_f = intersection_end_points[np.argmax(points_f_radians)]
        else:
            point_f = intersection_end_points[np.argmin(points_f_radians)]

        return point_b, point_d, point_e, point_f

    def __inverse_kinematics(self, temp_point_end):
        """Solve B, D and E from an end point F (defaults to the current F).

        Reads state only; never modifies it.

        :return: (point_b, point_d, point_e), or None if no solution
        """
        if temp_point_end is None:
            temp_point_end = self.point_f

        # Left joint B: intersection of circle A-B and circle F-B.
        intersection_left_points = calculate_intersection(self.point_a, temp_point_end, self.links[LINK_AB],
                                                          self.links[LINK_BF])
        if intersection_left_points is None:
            print("------------求左交点失败！")
            return None

        if self.is_left_mode():
            arg_index = np.argmin(intersection_left_points[:, 0])
        else:
            arg_index = np.argmax(intersection_left_points[:, 0])

        point_b = intersection_left_points[arg_index]

        # Joint E: intersection of circle B-E and circle F-E.
        e_points = calculate_intersection(point_b, temp_point_end, self.links[LINK_BE], self.links[LINK_EF])
        if e_points is None:
            print("------------求交点E失败！")
            return None

        # Intended: vector BF lies clockwise of vector BE (to the right of BE).
        # The E candidate is chosen by its angle from BF.
        vec_bf = temp_point_end - point_b
        points_e_radians = [calc_radians_between_vector(vec_bf, p_e - point_b) for p_e in e_points]

        if self.is_left_mode():
            point_e = e_points[np.argmin(points_e_radians)]
        else:
            point_e = e_points[np.argmax(points_e_radians)]

        # Right joint D: intersection of circle E-D and circle C-D.
        intersection_right_points = calculate_intersection(point_e, self.point_c, self.links[LINK_DE],
                                                           self.links[LINK_CD])
        if intersection_right_points is None:
            print("------------求右交点失败！")
            return None

        if self.is_left_mode():
            arg_index = np.argmax(intersection_right_points[:, 0])
        else:
            arg_index = np.argmin(intersection_right_points[:, 0])

        point_d = intersection_right_points[arg_index]

        return point_b, point_d, point_e

    def is_trigger_target(self, p_x, p_y, x, y, radius=8):
        """Return True if (x, y) lies strictly within ``radius`` of (p_x, p_y)."""
        return (p_x - x) ** 2 + (p_y - y) ** 2 < radius ** 2

    @staticmethod
    def _project_to_circle(center, on_circle, target):
        """Return the point on the circle through ``on_circle`` around ``center``, in the direction of ``target``.

        Returns None when ``target`` coincides with ``center`` (no direction).
        """
        vec_target = target - center
        target_norm = np.linalg.norm(vec_target)
        if target_norm == 0:
            return None
        return (np.linalg.norm(on_circle - center) / target_norm) * vec_target + center

    def handleMouseEvent(self, type, button, p: np.ndarray, update_cb):
        """Handle a mouse event to grab and drag B, D or the end point F.

        :param type: Qt event type (press / move / release)
        :param button: mouse button for press events
        :param p: mouse position, in the same coordinate space as the key points
        :param update_cb: called with no arguments after the pose or drag state changes
        """
        x, y = p

        if type == QEvent.MouseButtonPress:
            if button != QtCore.Qt.MouseButton.LeftButton:
                return

            if self.is_trigger_target(*self.point_b, x, y):
                self.capture_key_point = KeyPoint.POINT_LEFT
            elif self.is_trigger_target(*self.point_d, x, y):
                self.capture_key_point = KeyPoint.POINT_RIGHT
            elif self.is_trigger_target(*self.point_f, x, y):
                self.capture_key_point = KeyPoint.POINT_END

            print("---------------------capture: ", self.capture_key_point)

        elif type == QEvent.MouseMove:
            if self.capture_key_point == KeyPoint.POINT_NONE:
                return

            if self.capture_key_point == KeyPoint.POINT_LEFT:
                # Keep B on its circle around A, pointing towards the mouse.
                self.drag_line = (self.point_a, p)
                point = self._project_to_circle(self.point_a, self.point_b, p)
                if point is not None and self.__try_update_key_points(self.capture_key_point, point):
                    self.point_b = point
                    update_cb()
            elif self.capture_key_point == KeyPoint.POINT_RIGHT:
                # Keep D on its circle around C, pointing towards the mouse.
                self.drag_line = (self.point_c, p)
                point = self._project_to_circle(self.point_c, self.point_d, p)
                if point is not None and self.__try_update_key_points(self.capture_key_point, point):
                    self.point_d = point
                    update_cb()
            elif self.capture_key_point == KeyPoint.POINT_END:
                if self.__try_update_key_points(self.capture_key_point, p):
                    self.point_f = p
                    update_cb()

        elif type == QEvent.MouseButtonRelease:
            self.capture_key_point = KeyPoint.POINT_NONE
            self.drag_line = None
            update_cb()

    def try_update_point_end(self, end_point):
        """Move the end point F via inverse kinematics; return True on success."""
        if not isinstance(end_point, np.ndarray):
            end_point = np.array(end_point)

        if self.__try_update_key_points(KeyPoint.POINT_END, end_point):
            self.point_f = end_point
            return True
        return False

    def _point_b_for_alpha(self, alpha_degrees):
        """Return where B would be for the given alpha (degrees), without validating it."""
        org_vec_ab = self.links[LINK_AB] * (self.solid_link_side.value * self.refer_vector_x_negative)
        return self.point_a + rotate(-np.radians(alpha_degrees), org_vec_ab)

    def _point_d_for_beta(self, beta_degrees):
        """Return where D would be for the given beta (degrees), without validating it."""
        org_vec_cd = self.links[LINK_CD] * (self.solid_link_side.value * self.refer_vector_x_positive)
        return self.point_c + rotate(np.radians(beta_degrees), org_vec_cd)

    def try_update_beta(self, new_beta_degrees):
        """Rotate link CD to ``new_beta_degrees``; return True if the pose is valid and applied."""
        new_vec_od = self._point_d_for_beta(new_beta_degrees)
        if self.__try_update_key_points(KeyPoint.POINT_RIGHT, new_vec_od):
            # The pose is valid: keep the new point.
            self.point_d = new_vec_od
            return True

        return False

    def try_update_alpha(self, new_alpha_degrees):
        """Rotate link AB to ``new_alpha_degrees``; return True if the pose is valid and applied."""
        print("new_alpha_degrees: ", new_alpha_degrees)

        new_vec_ob = self._point_b_for_alpha(new_alpha_degrees)
        if self.__try_update_key_points(KeyPoint.POINT_LEFT, new_vec_ob):
            # The pose is valid: keep the new point.
            self.point_b = new_vec_ob
            return True

        return False

    def try_update_forward(self, new_forward_alpha_degrees, new_forward_beta_degrees):
        """Set both alpha and beta (degrees) at once; return True if the pose is valid and applied."""
        new_vec_ob = self._point_b_for_alpha(new_forward_alpha_degrees)
        new_vec_od = self._point_d_for_beta(new_forward_beta_degrees)

        if self.__try_update_key_points(KeyPoint.POINT_FORWARD, (new_vec_ob, new_vec_od)):
            # The pose is valid: keep the new points.
            self.point_b = new_vec_ob
            self.point_d = new_vec_od
            return True

        return False

    def inverse_ratio(self, value):
        """Convert screen units to physical units (lists/tuples become lists)."""
        if isinstance(value, (list, tuple)):
            return [v / self._ratio for v in value]
        return value / self._ratio

    def forward_ratio(self, value):
        """Convert physical units to screen units (lists/tuples become lists)."""
        if isinstance(value, (list, tuple)):
            return [v * self._ratio for v in value]
        return value * self._ratio

    def set_walks_trajectory(self, trajectory):
        """Store the trajectory to walk through."""
        self.walks_trajectory = trajectory

    def get_walks_trajectory(self):
        """Return the stored walk trajectory (None if unset)."""
        return self.walks_trajectory

    def get_trajectory_list(self) -> list:
        """Return the list of recorded trajectories."""
        return self.trajectory_list

    def get_current_end_point(self):
        """Return the current end point F in screen units."""
        return self.point_f